Generative Adversarial Networks (GAN)
Table of Contents
Source
If $P_{\text{model}}(x)$ can be made close to $P_{\text{data}}(x)$, then new data can be generated by sampling from $P_{\text{model}}(x)$.
In generative modeling, we'd like to train a network that models a distribution, such as a distribution over images.
GANs do not work with an explicit density function!
Instead, they take a game-theoretic approach.
One way to judge the quality of the model is to sample from it.
The model is trained to produce samples that are indistinguishable from the real data, as judged by a discriminator network whose job is to tell real from fake.
$$\text{loss} = -y \log h(x) - (1-y) \log (1-h(x))$$
Non-Saturating Game when the generator is trained
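Early in learning, when $G$ is poor, $D$ can reject samples with high confidence because they are clearly different from the training data. In this case, $\log(1-D(G(z)))$ saturates: its gradient with respect to $G$ becomes very small. Therefore, instead of minimizing $\log(1-D(G(z)))$, the generator is trained to maximize $\log D(G(z))$, which gives much stronger gradients early in training.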
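Step 1: Fix $G$ and perform a gradient step to
$$\max_{D} \; E_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + E_{z \sim p_{z}(z)}\left[\log \left(1 - D(G(z))\right)\right]$$
Step 2: Fix $D$ and perform a gradient step to
$$\min_{G} \; E_{z \sim p_{z}(z)}\left[\log \left(1 - D(G(z))\right)\right]$$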
OR
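Step 1: Fix $G$ and perform a gradient step to
$$\max_{D} \; E_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + E_{z \sim p_{z}(z)}\left[\log \left(1 - D(G(z))\right)\right]$$
Step 2: Fix $D$ and perform a gradient step to
$$\max_{G} \; E_{z \sim p_{z}(z)}\left[\log D(G(z))\right]$$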
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
(train_x, train_y), _ = tf.keras.datasets.mnist.load_data()
# Keep only images of the digit 2, so the vanilla GAN models a single class.
train_x = train_x[train_y == 2]
# Scale pixels from [0, 255] to [0, 1], the range of the generator's sigmoid output.
train_x = train_x/255.0
# Flatten each 28x28 image into a 784-vector for the fully connected discriminator.
train_x = train_x.reshape(-1, 784)
print('train_images :', train_x.shape)
# Vanilla GAN models (fully connected).
# The generator maps a 100-d noise vector to a flattened 28x28 image (784 values).
generator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, activation = 'relu', input_dim = 100),
    # sigmoid keeps generated pixels in (0, 1), the range of the normalized training images
    tf.keras.layers.Dense(units = 784, activation = 'sigmoid')
])
# The discriminator maps a flattened image to a single sigmoid score:
# the estimated probability that the image is real.
discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, activation = 'relu', input_dim = 784),
    tf.keras.layers.Dense(units = 1, activation = 'sigmoid'),
])
# Compile the discriminator while it is still trainable, so that
# discriminator.train_on_batch updates its weights.
discriminator.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001),
                      loss = 'binary_crossentropy')
# Freeze the discriminator *before* compiling the combined model, so that
# combined.train_on_batch updates only the generator.
# NOTE(review): this relies on Keras capturing `trainable` at compile time,
# so the order of these statements matters.
discriminator.trainable = False
# Combined model: noise -> generator -> (frozen) discriminator score.
combined_input = tf.keras.layers.Input(shape = (100,))
generated = generator(combined_input)
combined_output = discriminator(generated)
combined = tf.keras.models.Model(inputs = combined_input, outputs = combined_output)
combined.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002),
                 loss = 'binary_crossentropy')
def make_noise(samples):
    """Draw latent inputs for the vanilla GAN generator.

    Args:
        samples: number of latent vectors (rows) to draw.

    Returns:
        A (samples, 100) array of i.i.d. standard-normal values.
    """
    latent_shape = (samples, 100)
    return np.random.normal(loc = 0, scale = 1, size = latent_shape)
def plot_generated_images(generator, samples = 3):
    """Draw `samples` images from `generator` and show them side by side.

    Args:
        generator: model mapping (n, 100) noise, as produced by make_noise,
            to outputs that can be reshaped to (n, 28, 28).
        samples: number of images to generate and display in one row.
    """
    noise = make_noise(samples)
    # verbose = 0: this is called from inside the training loop, where a
    # predict() progress bar on every call only clutters the output.
    generated_images = generator.predict(noise, verbose = 0)
    generated_images = generated_images.reshape(samples, 28, 28)
    for i, image in enumerate(generated_images):
        plt.subplot(1, samples, i + 1)
        plt.imshow(image, 'gray', interpolation = 'nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
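Step 1: Fix $G$ and perform a gradient step to
$$\min_{D} \; E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log \left(1 - D(G(z))\right)\right]$$
Step 2: Fix $D$ and perform a gradient step to
$$\min_{G} \; E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$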
# Alternating GAN training with the non-saturating generator loss:
# D is trained on a real batch (label 1) and a generated batch (label 0);
# G is trained through the frozen D with "real" labels, i.e. it minimizes -log D(G(z)).
n_iter = 20000
batch_size = 100
fake = np.zeros(batch_size)
real = np.ones(batch_size)
for i in range(n_iter):
    # Train Discriminator
    noise = make_noise(batch_size)
    # A direct model call is much cheaper than predict() for one small batch,
    # because predict() sets up a full batched inference loop on every call.
    generated_images = generator(noise, training = False).numpy()
    idx = np.random.randint(0, train_x.shape[0], batch_size)
    real_images = train_x[idx]
    D_loss_real = discriminator.train_on_batch(real_images, real)
    D_loss_fake = discriminator.train_on_batch(generated_images, fake)
    D_loss = D_loss_real + D_loss_fake
    # Train Generator (through the combined model, whose discriminator is frozen)
    noise = make_noise(batch_size)
    G_loss = combined.train_on_batch(noise, real)
    if i % 5000 == 0:
        print('Discriminator Loss: ', D_loss)
        print('Generator Loss: ', G_loss)
        plot_generated_images(generator)
# Show samples from the final generator.
plot_generated_images(generator)
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
# Scale pixels to [0, 1], then flatten each 28x28 image into a 784-vector.
train_x = (train_x/255.0).reshape(-1, 784)
test_x = (test_x/255.0).reshape(-1, 784)
# One-hot encode the digit labels; they are the conditioning input of the CGAN.
train_y = tf.keras.utils.to_categorical(train_y, num_classes = 10)
test_y = tf.keras.utils.to_categorical(test_y, num_classes = 10)
for caption, arr in (('train_x: ', train_x), ('test_x: ', test_x),
                     ('train_y: ', train_y), ('test_y: ', test_y)):
    print(caption, arr.shape)
# Conditional GAN generator.
# Input: 128-d noise concatenated with a 10-d one-hot label (138 values).
# Output: a flattened 28x28 image with pixels in (0, 1).
generator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, activation = 'relu', input_dim = 138),
    tf.keras.layers.Dense(units = 784, activation = 'sigmoid')
])
noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (10,))
# Condition the generator by appending the label to the noise vector.
model_input = tf.keras.layers.concatenate([noise, label], axis = 1)
generated_image = generator_model(model_input)
generator = tf.keras.models.Model(inputs = [noise, label], outputs = generated_image)
generator.summary()
# Conditional discriminator.
# Input: a flattened image (784) concatenated with the one-hot label (10), i.e. 794 values.
# Output: the estimated probability that the (image, label) pair is real.
discriminator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, activation = 'relu', input_dim = 794),
    tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
])
input_image = tf.keras.layers.Input(shape = (784,))
label = tf.keras.layers.Input(shape = (10,))
model_input = tf.keras.layers.concatenate([input_image, label], axis = 1)
validity = discriminator_model(model_input)
discriminator = tf.keras.models.Model(inputs = [input_image, label], outputs = validity)
# Compile while trainable, so discriminator.train_on_batch updates its weights.
discriminator.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002),
                      loss = ['binary_crossentropy'])
discriminator.summary()
# Freeze the discriminator *before* compiling the combined model, so that
# combined.train_on_batch updates only the generator.
# NOTE(review): relies on Keras capturing `trainable` at compile time; order matters.
discriminator.trainable = False
# Combined model: (noise, label) -> generator -> frozen discriminator scores the
# generated image under the same label.
noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (10,))
generated_image = generator([noise, label])
validity = discriminator([generated_image, label])
combined = tf.keras.models.Model(inputs = [noise, label], outputs = validity)
combined.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002),
                 loss = ['binary_crossentropy'])
combined.summary()
def create_noise(samples):
    """Return a (samples, 128) array of standard-normal noise for the CGAN generator.

    Args:
        samples: number of noise vectors (rows) to draw.
    """
    return np.random.normal(loc = 0, scale = 1, size = (samples, 128))
def plot_generated_images(generator):
    """Generate and display one image for each digit class 0-9 in a single row.

    Args:
        generator: conditional generator taking [noise (n, 128), one-hot label (n, 10)]
            and returning outputs that can be reshaped to 28x28 images.
    """
    noise = create_noise(10)
    # Row k of the 10x10 identity matrix is the one-hot encoding of digit k.
    label_onehot = np.eye(10)
    # verbose = 0: called inside the training loop; skip predict()'s progress bar.
    generated_images = generator.predict([noise, label_onehot], verbose = 0)
    plt.figure(figsize = (12, 3))
    for digit, image in enumerate(generated_images):
        plt.subplot(1, 10, digit + 1)
        plt.imshow(image.reshape((28, 28)), 'gray', interpolation = 'nearest')
        plt.title('Digit: {}'.format(digit))
        plt.axis('off')
    plt.show()
# Conditional GAN training. D sees real (image, label) pairs (target 1) and
# generated images paired with the same labels (target 0). G is trained through
# the frozen D with random labels and "real" targets.
n_iter = 30000
batch_size = 50
valid = np.ones(batch_size)
fake = np.zeros(batch_size)
for i in range(n_iter):
    # Train Discriminator
    idx = np.random.randint(0, train_x.shape[0], batch_size)
    real_images, labels = train_x[idx], train_y[idx]
    noise = create_noise(batch_size)
    # A direct model call is much cheaper than predict() for one small batch.
    generated_images = generator([noise, labels], training = False).numpy()
    d_loss_real = discriminator.train_on_batch([real_images, labels], valid)
    d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake)
    d_loss = d_loss_real + d_loss_fake
    # Train Generator on uniformly random digit labels, one-hot encoded.
    noise = create_noise(batch_size)
    labels = np.random.randint(0, 10, batch_size)
    labels_onehot = np.eye(10)[labels]
    g_loss = combined.train_on_batch([noise, labels_onehot], valid)
    if i % 5000 == 0:
        print('Discriminator Loss: ', d_loss)
        print('Generator Loss: ', g_loss)
        plot_generated_images(generator)
In a standard generative model, there is no control on the features of the data being generated.
In the Information Maximizing GAN (InfoGAN), the generator additionally receives latent codes, such as continuous values in the range −1 to 1. It learns to tie these codes to interpretable features of the data, rather than producing a generic sample from an unstructured noise distribution.
The latent code in InfoGAN learns interpretable information from the data using unsupervised learning.
For instance, here are MNIST digits generated by varying a latent code:
InfoGAN is a simple modification of the original GAN framework: the latent code $c$ is fed to the generator, and an added Q network predicts the latent code $c$ of a fake sample $x_{\text{fake}}$.
The generative model learns interpretable information from the data by itself.
Generator at Conditional GAN
Generator at InfoGAN
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
(train_x, train_y), _ = tf.keras.datasets.mnist.load_data()
# Keep only images of the digit 2, so the latent codes capture style variation
# (e.g. slant, width) rather than digit identity.
train_x = train_x[train_y == 2]
# Scale pixels from [0, 255] to [0, 1], the range of the generator's sigmoid output.
train_x = train_x/255.0
# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1) for the convolutional networks.
train_x = train_x.reshape(-1, 28, 28, 1)
print('train_images :', train_x.shape)
# InfoGAN generator.
# Input: 62-d noise z concatenated with 2 continuous latent codes c (64 values).
# Output: a 28x28x1 image with pixels in (0, 1).
generator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 1024,
                          use_bias = False,
                          input_shape = (62 + 2,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Dense(units = 7*7*128,
                          use_bias = False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    # reshape to a 7x7 feature map with 128 channels
    tf.keras.layers.Reshape((7, 7, 128)),
    # stride-2 transposed convolution: 7x7 -> 14x14
    tf.keras.layers.Conv2DTranspose(64,
                                    (4, 4),
                                    strides = (2, 2),
                                    padding = 'same',
                                    use_bias = False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    # stride-2 transposed convolution: 14x14 -> 28x28 with one channel; the sigmoid
    # keeps pixels in the same [0, 1] range as the normalized training images
    tf.keras.layers.Conv2DTranspose(1,
                                    (4, 4),
                                    strides = (2, 2),
                                    padding = 'same',
                                    use_bias = False,
                                    activation = 'sigmoid')
])
# Shared feature extractor: 28x28x1 image -> 1024-d feature vector.
# Its output feeds both the real/fake head (d_network) and the latent-code head (q_network).
extractor = tf.keras.models.Sequential([
    # stride-2 convolution: 28x28 -> 14x14
    tf.keras.layers.Conv2D(64,
                           (4, 4),
                           strides = (2, 2),
                           padding = 'same',
                           use_bias = False,
                           input_shape = [28, 28, 1]),
    tf.keras.layers.LeakyReLU(),
    # stride-2 convolution: 14x14 -> 7x7
    tf.keras.layers.Conv2D(128,
                           (4, 4),
                           strides = (2, 2),
                           padding = 'same',
                           use_bias = False),
    tf.keras.layers.LayerNormalization(),
    tf.keras.layers.LeakyReLU(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units = 1024,
                          use_bias = False),
    tf.keras.layers.LayerNormalization(),
    tf.keras.layers.LeakyReLU()
])
# Real/fake head: 1024-d feature -> probability that the image is real.
d_network = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 1,
                          input_shape = (1024,),
                          use_bias = False,
                          activation = 'sigmoid')
])
# Q head: 1024-d feature -> prediction of the 2 continuous latent codes.
# The last layer has no activation (linear output); it is trained with MSE below.
q_network = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 128,
                          use_bias = False,
                          input_shape = (1024,)),
    tf.keras.layers.LayerNormalization(),
    tf.keras.layers.LeakyReLU(),
    tf.keras.layers.Dense(units = 2,
                          use_bias = False)
])
# Discriminator = extractor + d_network. It is compiled while both parts are
# trainable, so discriminator.train_on_batch updates them.
combined_input = tf.keras.layers.Input(shape = (28, 28, 1))
combined_feature = extractor(combined_input)
combined_output = d_network(combined_feature)
discriminator = tf.keras.models.Model(inputs = combined_input,
                                      outputs = combined_output)
discriminator.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 2e-4),
                      loss = 'binary_crossentropy')
# Freeze the discriminator parts *before* compiling the combined models, so their
# train_on_batch calls do not update extractor/d_network.
# NOTE(review): relies on Keras capturing `trainable` at compile time; order matters.
extractor.trainable = False
d_network.trainable = False
# Adversarial model for the generator:
# (z, c) -> generator -> frozen extractor -> frozen d_network.
combined_input = tf.keras.layers.Input(shape = (62 + 2,))
generated = generator(combined_input)
combined_feature = extractor(generated)
combined_output = d_network(combined_feature)
combined_d = tf.keras.models.Model(inputs = combined_input,
                                   outputs = combined_output)
combined_d.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3),
                   loss = 'binary_crossentropy')
# Repeats the freeze above; both flags are already False at this point.
extractor.trainable = False
d_network.trainable = False
# Latent-code model: (z, c) -> generator -> frozen extractor -> q_network.
# It is trained with MSE against c, so this loss updates the generator and
# q_network but not the extractor.
# NOTE(review): in the original InfoGAN, Q shares (and trains) the layers it
# shares with D; here the extractor is frozen for this loss. Confirm that is intended.
combined_input = tf.keras.layers.Input(shape = (62 + 2,))
generated = generator(combined_input)
combined_feature = extractor(generated)
combined_latent = q_network(combined_feature)
combined_q = tf.keras.models.Model(inputs = combined_input,
                                   outputs = combined_latent)
combined_q.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3),
                   loss = 'mean_squared_error')
def make_noise(samples):
    """Return a (samples, 62) array of InfoGAN noise z drawn uniformly from [-1, 1).

    Args:
        samples: number of noise vectors (rows) to draw.
    """
    return np.random.uniform(low = -1, high = 1, size = (samples, 62))
def make_code(samples):
    """Return a (samples, 2) array of continuous latent codes in [-1, 1).

    Args:
        samples: number of code vectors (rows) to draw.
    """
    unit = np.random.rand(samples, 2)  # uniform on [0, 1)
    return unit * 2 - 1
def plot_generated_images(generator, n_samples = 5):
    """Show how each of the two continuous latent codes changes the generated image.

    For each code in turn, one noise vector z is held fixed while that code is
    swept linearly from -1 to 1 (the other code stays at 0). The resulting
    images are shown in a single row.

    Args:
        generator: model mapping (n, 62 + 2) inputs [z, c] to outputs that can be
            reshaped to (n, 28, 28).
        n_samples: number of sweep points (images) per row; defaults to 5.
    """
    print('')
    for code_idx in range(2):
        # z must follow U(-1, 1), the distribution make_noise() feeds the generator
        # during training. Standard-normal z would put the generator out of distribution.
        z = np.random.uniform(-1, 1, size = (1, 62)).repeat(n_samples, axis = 0)
        c = np.zeros((n_samples, 2))
        c[:, code_idx] = np.linspace(-1, 1, n_samples)
        noise = np.concatenate([z, c], -1)
        generated_images = generator.predict(noise, verbose = 0)
        generated_images = generated_images.reshape(n_samples, 28, 28)
        print('Continuous Latent Code {}'.format(code_idx + 1))
        for i in range(n_samples):
            plt.subplot(1, n_samples, i + 1)
            plt.imshow(generated_images[i], 'gray', interpolation = 'nearest')
            plt.axis('off')
        plt.tight_layout()
        plt.show()
    print('')
# InfoGAN training. Each iteration runs two discriminator updates, then one
# generator update (adversarial loss) and one generator + Q-network update
# (MSE between the sampled code c and Q's prediction of it).
n_iter = 5000
batch_size = 256
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for i in range(n_iter):
    # Train Discriminator (2 steps per iteration)
    for _ in range(2):
        z = make_noise(batch_size)
        c = make_code(batch_size)
        noise = np.concatenate([z, c], -1)
        # A direct call is much cheaper than predict() for one small batch;
        # training = False keeps BatchNorm in inference mode, as predict() did.
        generated_images = generator(noise, training = False).numpy()
        # real batch sampled without replacement
        idx = np.random.choice(len(train_x), batch_size, replace = False)
        real_images = train_x[idx]
        D_loss_real = discriminator.train_on_batch(real_images, real)
        D_loss_fake = discriminator.train_on_batch(generated_images, fake)
        D_loss = D_loss_real + D_loss_fake
    # Train Generator & Q Net (1 step per iteration)
    z = make_noise(batch_size)
    c = make_code(batch_size)
    noise = np.concatenate([z, c], -1)
    G_loss = combined_d.train_on_batch(noise, real)
    # Q Net regresses the code c that was fed to the generator.
    Q_loss = combined_q.train_on_batch(noise, c)
    # Print Loss
    if (i + 1) % 500 == 0:
        print('Epoch: {:5d} | Discriminator Loss: {:.3f} | Generator Loss: {:.3f} | Q Net Loss: {:.3f}'.format(i + 1, D_loss, G_loss, Q_loss))
        plot_generated_images(generator)
# Visualize what each continuous latent code controls. In each 8x8 grid,
# every row shares one noise vector z, and the columns sweep one code from
# -1 to 1 while the other code stays at 0.
n_rows, n_cols = 8, 8
sweep = np.linspace(-1, 1, n_cols)
images_save_1 = []
images_save_2 = []
for i in range(n_rows):
    # z must follow U(-1, 1), the distribution make_noise() uses during training.
    z = np.random.uniform(-1, 1, size = (1, 62)).repeat(n_cols, axis = 0)
    for code_idx, images_store in enumerate((images_save_1, images_save_2)):
        c = np.zeros((n_cols, 2))
        c[:, code_idx] = sweep
        noise = np.concatenate([z, c], -1)
        generated_images = generator.predict(noise, verbose=0)
        images_store.append(generated_images.reshape(n_cols, 28, 28))
for title, images_grid in (('Continuous Latent Code 1', images_save_1),
                           ('Continuous Latent Code 2', images_save_2)):
    print(title)
    fig, ax = plt.subplots(n_rows, n_cols, figsize = (10, 10))
    for i in range(n_rows):
        for j in range(n_cols):
            ax[i][j].imshow(images_grid[i][j], 'gray')
            ax[i][j].set_xticks([])
            ax[i][j].set_yticks([])
    plt.show()
Ian Goodfellow, et al., "Generative Adversarial Nets" NIPS, 2014.
At NIPS 2016 by Ian Goodfellow
%%html
<center><iframe src="https://www.youtube.com/embed/9JpdAg6uMXs?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%html
<center><iframe src="https://www.youtube.com/embed/5WoItGTWV54?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
MIT by Aaron Courville
%%html
<center><iframe src="https://www.youtube.com/embed/JVb54xhEw6Y?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%html
<center><iframe src="https://www.youtube.com/embed/odpjk7_tGY0?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')